SigmoidCrossEntropyWithLogits
计算预测值与真实值之间的sigmoid交叉熵。
测量离散分类任务中的分布误差,每个类相互独立,且计算出各个类的交叉熵损失。
将输入 logits 设置为 \(X\),输入 label 为 \(Y\),输出为 \(loss\)。然后,
\[\begin{split}\begin{aligned}
p &= \text{sigmoid}(X) = \frac{1}{1 + e^{-X}} \\
loss &= -[Y \cdot \ln(p) + (1 - Y) \cdot \ln(1 - p)]
\end{aligned}\end{split}\]
- 输入:
input0 - 输入 logits 张量地址。
input1 - 输入标签张量地址,与 logits 形状相同。
length - 张量元素总数。
core_mask(int, 可选) - 核掩码(仅适用于共享存储版本)。
- 输出:
output - 输出损失张量地址,与输入张量形状相同。
- 支持平台:
FT78NEMT7004
备注
FT78NE 支持的数据类型:int8, fp32
MT7004 支持的数据类型:fp16, fp32
共享存储版本:
-
void i8_sigmoidcrossentropywithlogits_s(int8_t *input0, int8_t *input1, int8_t *output, int length, int core_mask)
-
void fp_sigmoidcrossentropywithlogits_s(float *input0, float *input1, float *output, int length, int core_mask)
-
void hp_sigmoidcrossentropywithlogits_s(half *input0, half *input1, half *output, int length, int core_mask)
C调用示例:
1// FT78NE 多核示例 2#include <stdio.h> 3#include <sigmoidcrossentropywithlogits.h> 4 5int main(int argc, char* argv[]) { 6 float *input0 = (float *)0xA0000000; // logits在DDR空间 7 float *input1 = (float *)0xB0000000; // label在DDR空间 8 float *output = (float *)0xC0000000; // 输出损失在DDR空间 9 int length = 1000; 10 int core_mask = 0xff; 11 12 // 计算 sigmoid 交叉熵损失 13 fp_sigmoidcrossentropywithlogits_s(input0, input1, output, length, core_mask); 14 return 0; 15}
私有存储版本:
-
void i8_sigmoidcrossentropywithlogits_p(int8_t *input0, int8_t *input1, int8_t *output, int length)
-
void fp_sigmoidcrossentropywithlogits_p(float *input0, float *input1, float *output, int length)
-
void hp_sigmoidcrossentropywithlogits_p(half *input0, half *input1, half *output, int length)
C调用示例:
1// MT7004 单核示例 2#include <stdio.h> 3#include <sigmoidcrossentropywithlogits.h> 4 5int main(int argc, char* argv[]) { 6 half *input0 = (half *)0x10000000; // logits在L2空间 7 half *input1 = (half *)0x10004000; // label在L2空间 8 half *output = (half *)0x10008000; // 输出损失在L2空间 9 int length = 1000; 10 11 // 计算 sigmoid 交叉熵损失 12 hp_sigmoidcrossentropywithlogits_p(input0, input1, output, length); 13 return 0; 14}